{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kSvQC3zGDAcn"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Probability Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "b1a_UlukDhK-"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j1tv9qSiCY0i"
   },
   "source": [
    "# Probabilistic Programming in Oryx\n",
    "\n",
    "\n",
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://www.tensorflow.org/probability/oryx/examples/probabilistic_programming\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/spinoffs/oryx/examples/notebooks/probabilistic_programming.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/spinoffs/oryx/examples/notebooks/probabilistic_programming.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/probability/spinoffs/oryx/examples/notebooks/probabilistic_programming.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "ZKnHamj7mOCw"
   },
   "outputs": [],
   "source": [
    "!pip install -U jax jaxlib\n",
    "!pip install -Uq oryx -I\n",
    "!pip install tfp-nightly --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "arhBvhejmcnS"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style='white')\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import jit, vmap, grad\n",
    "from jax import random\n",
    "\n",
    "from tensorflow_probability.substrates import jax as tfp\n",
    "tfd = tfp.distributions\n",
    "\n",
    "import oryx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SEpFsLcBhJCy"
   },
   "source": [
    "Probabilistic programming is the idea that we can express probabilistic models using features from a programming language. Tasks like Bayesian inference or marginalization are then provided as language features and can potentially be automated.\n",
    "\n",
    "Oryx provides a probabilistic programming system in which probabilistic programs are just expressed as Python functions; these programs are then transformed via composable function transformations like those in JAX! The idea is to start with simple programs (like sampling from a random normal) and compose them together to form models (like a Bayesian neural network). An important point of Oryx's PPL design is to enable programs to look like functions you'd already write and use in JAX, but are *annotated* to make transformations aware of them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tv08TkyTvOC9"
   },
   "source": [
    "Let's first import Oryx's core PPL functionality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "7_16c5HBp4F6"
   },
   "outputs": [],
   "source": [
    "from oryx.core.ppl import random_variable\n",
    "from oryx.core.ppl import log_prob\n",
    "from oryx.core.ppl import joint_sample\n",
    "from oryx.core.ppl import joint_log_prob\n",
    "from oryx.core.ppl import block\n",
    "from oryx.core.ppl import intervene\n",
    "from oryx.core.ppl import conditional\n",
    "from oryx.core.ppl import graph_replace\n",
    "from oryx.core.ppl import nest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wHbTrVRkuyab"
   },
   "source": [
    "## What are probabilistic programs in Oryx?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G38YnjrGT6Ub"
   },
   "source": [
    "In Oryx, probabilistic programs are just pure Python functions that operate on JAX values and pseudorandom keys and return a random sample. By design, they are compatible with transformations like `jit` and `vmap`. However, the Oryx probabilistic programming system provides tools that enable you to *annotate* your functions in useful ways."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3KJEnMF8UoLj"
   },
   "source": [
    "Following the JAX philosophy of pure functions, an Oryx probabilistic program is a Python function that takes a JAX `PRNGKey` as its first argument and any number of subsequent conditioning arguments. The output of the function is called a \"sample\" and the same restrictions that apply to `jit`-ed and `vmap`-ed functions apply to probabilistic programs (e.g. no data-dependent control flow, no side effects, etc.). This differs from many imperative probabilistic programming systems in which a 'sample' is the entire execution trace, including values internal to the program's execution. We will see later how Oryx can access internal values using the `joint_sample`, discussed below.\n",
    "\n",
    "```\n",
    "Program :: PRNGKey -> ... -> Sample\n",
    "```\n",
    "Here is a \"hello world\" program that samples from a [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Fhls0XS3u225"
   },
   "outputs": [],
   "source": [
    "def log_normal(key):\n",
    "  return jnp.exp(random_variable(tfd.Normal(0., 1.))(key))\n",
    "  \n",
    "print(log_normal(random.PRNGKey(0)))\n",
    "sns.distplot(jit(vmap(log_normal))(random.split(random.PRNGKey(0), 10000)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xNiaDJeC0zyo"
   },
   "source": [
    "The `log_normal` function is a thin wrapper around a [Tensorflow Probability (TFP)](https://www.tensorflow.org/probability) distribution, but instead of calling `tfd.Normal(0., 1.).sample`, we've used `random_variable` instead. As we'll see later, `random_variable` enables us to convert objects into probabilistic programs, along with other useful functionality."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6YK95c-PvEJl"
   },
   "source": [
    "We can convert `log_normal` into a log-density function using the `log_prob` transformation: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "07HpsrfUwrEm"
   },
   "outputs": [],
   "source": [
    "print(log_prob(log_normal)(1.))\n",
    "x = jnp.linspace(0., 5., 1000)\n",
    "plt.plot(x, jnp.exp(vmap(log_prob(log_normal))(x)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MceVcz6npiQK"
   },
   "source": [
    "Because we've annotated the function with `random_variable`, `log_prob` is aware that there was a call to `tfd.Normal(0., 1.).sample` and uses `tfd.Normal(0., 1.).log_prob` to compute the base distribution log prob. To handle the `jnp.exp`, `ppl.log_prob` automatically computes densities through bijective functions, keeping track of volume changes in the change-of-variable computation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VbMBr2rEuuen"
   },
   "source": [
    "In Oryx, we can take programs and transform them using function transformations -- for example, `jax.jit` or `log_prob`. Oryx can't do this with just any program though; it requires sampling functions that have registered their log density function with Oryx.\n",
    "Fortunately, Oryx automatically registers [TensorFlow Probability](https://www.tensorflow.org/probability) (TFP) distributions in its system."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VtwFoenmKHMi"
   },
   "source": [
    "## Oryx's probabilistic programming tools\n",
    "\n",
    "Oryx has several function transformations geared towards probabilistic programming. We'll go over most of them and provide some examples. At the end, we'll put it all together into an MCMC case study. You can also refer to the documentation for `core.ppl.transformations` for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3OqRM-76qZil"
   },
   "source": [
    "### `random_variable`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FlCjr8CUp5Ai"
   },
   "source": [
    "`random_variable` has two main pieces of functionality, both focused on annotating Python functions with information that can be used in transformations.\n",
    "\n",
    "1. `random_variable`' operates as the identity function by default, but can use type-specific registrations to convert objects into probabilistic programs.`\n",
    "\n",
    "  For callable types (Python functions, lambdas, `functools.partial`s, etc.) and arbitrary `object`s (like JAX `DeviceArray`s) it will just return its input.\n",
    "\n",
    "  ```python\n",
    "  random_variable(x: object) == x\n",
    "  random_variable(f: Callable[...]) == f\n",
    "  ```\n",
    "  Oryx automatically registers [TensorFlow Probability (TFP)](https://www.tensorflow.org/probability) distributions, which are converted into probabilistic programs that call the distribution's `sample` method.\n",
    "  \n",
    "  ```python\n",
    "  random_variable(tfd.Normal(0., 1.))(random.PRNGKey(0)) # ==> -0.20584235\n",
    "  ```\n",
    "  Oryx additionally embeds information about the TFP distribution into JAX traces that enables automatically computing log densities.\n",
    "2. `random_variable` can *tag* values with names, making them useful for downstream transformations, by providing an optional `name` keyword argument to `random_variable`. When we pass an array into `random_variable` along with a `name` (e.g. `random_variable(x, name='x')`), it just tags the value and returns it. If we pass in a callable or TFP distribution, `random_variable` returns a program that tags its output sample with `name`.\n",
    "\n",
    "These annotations do not change the *semantics* of the program when executed, but only when transformed (i.e. the program will return the same value with or without the use of `random_variable`). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y1Ei3r0Lz1SY"
   },
   "source": [
    "Let's go over an example where we use both pieces of functionality together. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "FF2G3bDSz6xM"
   },
   "outputs": [],
   "source": [
    "def latent_normal(key):\n",
    "  z_key, x_key = random.split(key)\n",
    "  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)\n",
    "  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TCjBPSRD2_cl"
   },
   "source": [
    "In this program we've tagged the intermediates `z` and `x`, which makes the transformations `joint_sample`, `intervene`, `conditional` and `graph_replace` aware of the names `'z'` and `'x'`. We'll go over exactly how each transformation uses names later."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o1TDjshTpX3E"
   },
   "source": [
    "### `log_prob`\n",
    "\n",
    "The `log_prob` function transformation converts an Oryx probabilistic program into its log-density function. This log-density function takes a potential sample from the program as input and returns its log-density under the underlying sampling distribution.\n",
    "\n",
    "```python\n",
    "log_prob :: Program -> (Sample -> LogDensity)\n",
    "```\n",
    "\n",
    "Like `random_variable`, it works via a registry of types where TFP distributions are automatically registered, so `log_prob(tfd.Normal(0., 1.))` calls `tfd.Normal(0., 1.).log_prob`. For Python functions, however, `log_prob` traces the program using JAX and looks for sampling statements.\n",
    "The `log_prob` transformation works on most programs that return random variables, directly or via invertible transformations but not on programs that sample values internally that aren't returned. If it cannot invert the necessary operations in the program, `log_prob` will throw an error.\n",
    "\n",
    "Here are some examples of `log_prob` applied to various programs.\n",
    "\n",
    "1. `log_prob` works on programs that directly sample from TFP distributions (or other registered types) and return their values.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "C4vOzXrB-4Fa"
   },
   "outputs": [],
   "source": [
    "def normal(key):\n",
    "  return random_variable(tfd.Normal(0., 1.))(key)\n",
    "print(log_prob(normal)(0.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vgJ-qhwr_A5u"
   },
   "source": [
    "2. `log_prob` is able to compute log-densities of samples from programs that transform random variates using bijective functions (e.g `jnp.exp`, `jnp.tanh`, `jnp.split`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "BjudBvQO_aSP"
   },
   "outputs": [],
   "source": [
    "def log_normal(key):\n",
    "  return 2 * jnp.exp(random_variable(tfd.Normal(0., 1.))(key))\n",
    "print(log_prob(log_normal)(1.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4HlyVSOrGeuN"
   },
   "source": [
    "  In order to compute a sample from `log_normal`'s log-density, we first need to invert the `exp`, taking the `log` of the sample, and then add a volume-change correction using the inverse log-det Jacobian of `exp` (see the [change of variable](https://en.wikipedia.org/wiki/Probability_density_function#Function_of_random_variables_and_change_of_variables_in_the_probability_density_function) formula from Wikipedia)."
   ]
  },
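  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the change-of-variables computation described above, here is a minimal sketch in plain Python that does not use Oryx. The helper names `normal_log_prob` and `scaled_log_normal_log_prob` are illustrative and are not part of Oryx; the printed value should agree with `log_prob(log_normal)(1.)` up to floating-point precision.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normal_log_prob(x, loc=0.0, scale=1.0):\n",
    "  # Log-density of Normal(loc, scale) evaluated at x.\n",
    "  z = (x - loc) / scale\n",
    "  return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)\n",
    "\n",
    "def scaled_log_normal_log_prob(y, factor=2.0):\n",
    "  # Invert y = factor * exp(z) to recover the base sample z.\n",
    "  z = math.log(y / factor)\n",
    "  # Inverse log-det Jacobian: dz/dy = 1 / y, so log|dz/dy| = -log(y).\n",
    "  ildj = -math.log(y)\n",
    "  return normal_log_prob(z) + ildj\n",
    "\n",
    "print(scaled_log_normal_log_prob(1.0))\n",
    "```\n",
    "\n",
    "Setting `factor=1.0` recovers the density of the original, unscaled `log_normal` program from the start of this notebook."
   ]
  },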
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LEpqed8u_l8l"
   },
   "source": [
    "3. `log_prob` works with programs that output structures of samples like, Python dictionaries or tuples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "snC1Ax9SzGSR"
   },
   "outputs": [],
   "source": [
    "def normal_2d(key):\n",
    "  x = random_variable(\n",
    "    tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)))(key)\n",
    "  x1, x2 = jnp.split(x, 2, 0)\n",
    "  return dict(x1=x1, x2=x2)\n",
    "sample = normal_2d(random.PRNGKey(0))\n",
    "print(sample)\n",
    "print(log_prob(normal_2d)(sample))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VGPSYvCx-3GQ"
   },
   "source": [
    "4. `log_prob` walks the traced computation graph of the function, computing both forward and inverse values (and their log-det Jacobians) when necessary in an attempt to connect returned values with their base sampled values via a well-defined change of variables. Take the following example program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "gKY3xWoz25kc"
   },
   "outputs": [],
   "source": [
    "def complex_program(key):\n",
    "  k1, k2 = random.split(key)\n",
    "  z = random_variable(tfd.Normal(0., 1.))(k1)\n",
    "  x = random_variable(tfd.Normal(jax.nn.relu(z), 1.))(k2)\n",
    "  return jnp.exp(z), jax.nn.sigmoid(x)\n",
    "sample = complex_program(random.PRNGKey(0))\n",
    "print(sample)\n",
    "print(log_prob(complex_program)(sample))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xl2eMLLRImoZ"
   },
   "source": [
    "In this program, we sample `x` conditionally on `z`, meaning we need the value of `z` before we can compute the log-density of `x`. However, in order to compute `z`, we first have to invert the `jnp.exp` applied to `z`. Thus, in order to compute the log-densities of `x` and `z`, `log_prob` needs to first invert the first output, and then pass it forward through the `jax.nn.relu` to compute the mean of `p(x | z)`.\n",
    "\n",
    "Note: while `log_prob` will not work on programs that sample random variables that are not returned as outputs, `joint_sample` can convert a program into one that does."
   ]
  },
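  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The order of operations described above (invert, push forward, invert again) can be written out by hand. The following is a sketch in plain Python under the assumption that both densities are normal as in `complex_program`; the function `complex_program_log_prob` is a hypothetical helper, not something Oryx provides, and it only mirrors what `log_prob` has to work out automatically.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normal_log_prob(x, loc=0.0, scale=1.0):\n",
    "  # Log-density of Normal(loc, scale) evaluated at x.\n",
    "  z = (x - loc) / scale\n",
    "  return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)\n",
    "\n",
    "def complex_program_log_prob(sample):\n",
    "  exp_z, sigmoid_x = sample\n",
    "  # Step 1: invert the first output to recover z; log|dz/d(exp z)| = -z.\n",
    "  z = math.log(exp_z)\n",
    "  log_p_z = normal_log_prob(z) - z\n",
    "  # Step 2: push z forward through relu to get the mean of p(x | z).\n",
    "  mean_x = max(z, 0.0)\n",
    "  # Step 3: invert the sigmoid (logit); log|dx/ds| = -log(s) - log(1 - s).\n",
    "  x = math.log(sigmoid_x) - math.log1p(-sigmoid_x)\n",
    "  ildj_x = -math.log(sigmoid_x) - math.log1p(-sigmoid_x)\n",
    "  log_p_x = normal_log_prob(x, loc=mean_x) + ildj_x\n",
    "  return log_p_z + log_p_x\n",
    "\n",
    "print(complex_program_log_prob((1.0, 0.5)))\n",
    "```"
   ]
  },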
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zfM7TKvz3ZBK"
   },
   "source": [
    "For more information about `log_prob`, you can refer to `core.interpreters.log_prob`. In implementation, `log_prob` is closely based off of the `inverse` JAX transformation; to learn more about `inverse`, see `core.interpreters.inverse`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-T8WvddQJXgq"
   },
   "source": [
    "### `joint_sample`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kwQWbDzLJY-i"
   },
   "source": [
    "To define more complex and interesting programs, we'll use some latent random variables, i.e. random variables with unobserved values. Let's refer to the `latent_normal` program that samples a random value `z` that is used as the mean of another random value `x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "sf6MbsXKK0ev"
   },
   "outputs": [],
   "source": [
    "def latent_normal(key):\n",
    "  z_key, x_key = random.split(key)\n",
    "  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)\n",
    "  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MePD5Qj-Kz4t"
   },
   "source": [
    "In this program, `z` is latent so if we were to just call `latent_normal(random.PRNGKey(0))` we would not know the actual value of `z` that is responsible for generating `x`.\n",
    "\n",
    "`joint_sample` is a transformation that transforms a program into another program that returns a dictionary mapping string names (tags) to their values.\n",
    "In order to work, we need to make sure we tag the latent variables to ensure they appear in the transformed function's output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "SbGeWdn4K7WX"
   },
   "outputs": [],
   "source": [
    "joint_sample(latent_normal)(random.PRNGKey(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-E9P8ZEKLBuh"
   },
   "source": [
    "Note that `joint_sample` transforms a program into another program that samples the joint distribution over its latent values, so we can further transform it. For algorithms like MCMC and VI, it's common to compute the log probability of the joint distribution as part of the inference procedure. `log_prob(latent_normal)` doesn't work because it requires marginalizing out `z`, but we can use `log_prob(joint_sample(latent_normal))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "-c-Ke3YKLj-f"
   },
   "outputs": [],
   "source": [
    "print(log_prob(joint_sample(latent_normal))(dict(x=0., z=1.)))\n",
    "print(log_prob(joint_sample(latent_normal))(dict(x=0., z=-10.)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mbf-Rl9XL03x"
   },
   "source": [
    "Because this is such a common pattern, Oryx also has a `joint_log_prob` transformation which is just the composition of `log_prob` and `joint_sample`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "mm28T0o6My7-"
   },
   "outputs": [],
   "source": [
    "print(joint_log_prob(latent_normal)(dict(x=0., z=1.)))\n",
    "print(joint_log_prob(latent_normal)(dict(x=0., z=-10.)))"
   ]
  },
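  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this particular model the joint log-density factorizes as `log p(z) + log p(x | z)`, so it can be checked by hand. The sketch below does that in plain Python without Oryx; `latent_normal_joint_log_prob` is an illustrative helper name, and its outputs should match the `joint_log_prob` values above up to floating-point precision.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normal_log_prob(x, loc=0.0, scale=1.0):\n",
    "  # Log-density of Normal(loc, scale) evaluated at x.\n",
    "  z = (x - loc) / scale\n",
    "  return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)\n",
    "\n",
    "def latent_normal_joint_log_prob(sample):\n",
    "  # log p(z, x) = log p(z) + log p(x | z), with z ~ N(0, 1), x | z ~ N(z, 0.1).\n",
    "  z, x = sample['z'], sample['x']\n",
    "  return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1e-1)\n",
    "\n",
    "print(latent_normal_joint_log_prob(dict(x=0., z=1.)))\n",
    "print(latent_normal_joint_log_prob(dict(x=0., z=-10.)))\n",
    "```"
   ]
  },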
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aWy_f6FDNTIf"
   },
   "source": [
    "### `block`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Sx7vuRFPgqWC"
   },
   "source": [
    "The `block` transformation takes in a program and a sequence of names and returns a program that behaves identically except that in downstream transformations (like `joint_sample`), the provided names are ignored.\n",
    "An example of where `block` is handy is converting a joint distribution into a prior over the latent variables by \"blocking\" the values sampled in the likelihood.\n",
    "For example, take `latent_normal`, which first draws a `z ~ N(0, 1)` then an `x | z ~ N(z, 1e-1)`. `block(latent_normal, names=['x'])` is a program that hides the `x` name, so if we do `joint_sample(block(latent_normal, names=['x']))`, we obtain a dictionary with just `z` in it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "E92Ip1sihD1R"
   },
   "outputs": [],
   "source": [
    "blocked = block(latent_normal, names=['x'])\n",
    "joint_sample(blocked)(random.PRNGKey(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AcuszFsTgoip"
   },
   "source": [
    "### `intervene`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nkvyy1jhNV9-"
   },
   "source": [
    "The `intervene` transformation clobbers samples in a probabilistic program with values from the outside. Going back to our `latent_normal` program, let's say we were interested in running the same program but wanted `z` to be fixed to 4. Rather than writing a new program, we can use `intervene` to override the value of `z`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "6cJbpzb5Npf0"
   },
   "outputs": [],
   "source": [
    "intervened = intervene(latent_normal, z=4.)\n",
    "sns.distplot(vmap(intervened)(random.split(random.PRNGKey(0), 10000)))\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SefNiVNCN6q0"
   },
   "source": [
    "The `intervened` function samples from `p(x | do(z = 4))` which is just a standard normal distribution centered at 4. When we `intervene` on a particular value, that value is *no longer considered a random variable*. This means that a `z` value will not be tagged while executing `intervened`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XKViR31cSr5H"
   },
   "source": [
    "### `conditional`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aSdSyDIXSufY"
   },
   "source": [
    "`conditional` transforms a program that samples latent values into one that conditions on those latent values. Returning to our `latent_normal` program, which samples `p(x)` with a latent `z`, we can convert it into a conditional program `p(x | z)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "qDpjwG98SEe7"
   },
   "outputs": [],
   "source": [
    "cond_program = conditional(latent_normal, 'z')\n",
    "print(cond_program(random.PRNGKey(0), 100.))\n",
    "print(cond_program(random.PRNGKey(0), 50.))\n",
    "sns.distplot(vmap(lambda key: cond_program(key, 1.))(random.split(random.PRNGKey(0), 10000)))\n",
    "sns.distplot(vmap(lambda key: cond_program(key, 2.))(random.split(random.PRNGKey(0), 10000)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_mPGbUsPXXpy"
   },
   "source": [
    "### `nest`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SLlGjRwPXabc"
   },
   "source": [
    "When we start composing probabilistic programs to build more complex ones, it's common to reuse functions that have some important logic. For example, if we'd like to build a Bayesian neural network, there might be an important `dense` program that samples weights and executes a forward pass.\n",
    "\n",
    "If we reuse functions, however, we might end up with duplicate tagged values in the final program, which is disallowed by transformations like `joint_sample`. We can use the `nest` to create tag \"scopes\" where any samples inside of a named scope will be inserted into a nested dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "w6f8eFruXZjX"
   },
   "outputs": [],
   "source": [
    "def f(key):\n",
    "  return random_variable(tfd.Normal(0., 1.), name='x')(key)\n",
    "\n",
    "def g(key):\n",
    "  k1, k2 = random.split(key)\n",
    "  return nest(f, scope='x1')(k1) + nest(f, scope='x2')(k2)\n",
    "joint_sample(g)(random.PRNGKey(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cWcGIi9vY1QH"
   },
   "source": [
    "## Case study: Bayesian neural network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xhTOYwgYKt5S"
   },
   "source": [
    "Let's try our hand at training a Bayesian neural network for classifying the classic [Fisher Iris](https://archive.ics.uci.edu/ml/datasets/iris) dataset. It's relatively small and low-dimensional so we can try directly sampling the posterior with MCMC."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "C2ILC06mK6B5"
   },
   "source": [
    "First, let's import the dataset and some additional utilities from Oryx."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "EaOJyqfN4fmn"
   },
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "iris = datasets.load_iris()\n",
    "features, labels = iris['data'], iris['target']\n",
    "\n",
    "num_features = features.shape[-1]\n",
    "num_classes = len(iris.target_names)\n",
    "\n",
    "from oryx.experimental import mcmc\n",
    "from oryx.util import summary, get_summaries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tow5P0PAK80o"
   },
   "source": [
    "We begin by implementing a dense layer, which will have normal priors over the weights and bias. To do this, we first define a `dense` higher order function that takes in the desired output dimension and activation function. The `dense` function returns a probabilistic program that represents a conditional distribution `p(h | x)` where `h` is the output of a dense layer and `x` is its input. It first samples the weight and bias and then applies them to `x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "wwY-yGXsY2fT"
   },
   "outputs": [],
   "source": [
    "def dense(dim_out, activation=jax.nn.relu):\n",
    "  def forward(key, x):\n",
    "    dim_in = x.shape[-1]\n",
    "    w_key, b_key = random.split(key)\n",
    "    w = random_variable(\n",
    "          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),\n",
    "          name='w')(w_key)\n",
    "    b = random_variable(\n",
    "          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),\n",
    "          name='b')(b_key)\n",
    "    return activation(jnp.dot(w, x) + b)\n",
    "  return forward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BRqrN-X9LQ3h"
   },
   "source": [
    "To compose several `dense` layers together, we will implement an `mlp` (multilayer perceptron) higher order function which takes in a list of hidden sizes and a number of classes. It returns a program that repeatedly calls `dense` using the appropriate `hidden_size` and finally returns logits for each class in the final layer. Note the use of `nest` which creates name scopes for each layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Hd905GHULQBs"
   },
   "outputs": [],
   "source": [
    "def mlp(hidden_sizes, num_classes):\n",
    "  num_hidden = len(hidden_sizes)\n",
    "  def forward(key, x):\n",
    "    keys = random.split(key, num_hidden + 1)\n",
    "    for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):\n",
    "      x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)\n",
    "    logits = nest(dense(num_classes, activation=lambda x: x),\n",
    "                  scope=f'layer_{num_hidden + 1}')(keys[-1], x)\n",
    "    return logits\n",
    "  return forward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NwzkkEenLobb"
   },
   "source": [
    "To implement the full model, we'll need to model the labels as categorical random variables. We'll define a `predict` function which takes in a dataset of `xs` (the features) which are then passed into an `mlp` using `vmap`. When we use `vmap(partial(mlp, mlp_key))`, we sample a single set of weights, but map the forward pass over all the input `xs`. This produces a set of `logits` which parameterizes independent categorical distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "nHqUvZe_Ll1m"
   },
   "outputs": [],
   "source": [
    "def predict(mlp):\n",
    "  def forward(key, xs):\n",
    "    mlp_key, label_key = random.split(key)\n",
    "    logits = vmap(partial(mlp, mlp_key))(xs)\n",
    "    return random_variable(\n",
    "        tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)\n",
    "  return forward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UTgi-vflMAsb"
   },
   "source": [
    "That's the full model! Let's use MCMC to sample the posterior of the BNN weights given data; first we construct a BNN \"template\" using `mlp`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "icKfgHEG2vTo"
   },
   "outputs": [],
   "source": [
    "bnn = mlp([200, 200], num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wBS7Pt8cND5j"
   },
   "source": [
    "To construct a starting point for our Markov chain, we can use `joint_sample` with a dummy input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "lZ36LOC5a1oV"
   },
   "outputs": [],
   "source": [
    "weights = joint_sample(bnn)(random.PRNGKey(0), jnp.ones(num_features))\n",
    "print(weights.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Szb3iHs4Ne5X"
   },
   "source": [
    "Computing the joint distribution log probability is sufficient for many inference algorithms. Let's now say we observe `x` and want to sample the posterior `p(z | x)`. For complex distributions, we won't be able to marginalize out `x` (though for `latent_normal` we can) but we can compute an *unnormalized* log density `log p(z, x)` where `x` is fixed to a particular value. We can use the unnormalized log probability with MCMC to sample the posterior. Let's write this \"pinned\" log prob function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "BGK64cK621QI"
   },
   "outputs": [],
   "source": [
    "def target_log_prob(weights):\n",
    "  return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aCe1smeqNlAv"
   },
   "source": [
    "Now we can use `tfp.mcmc` to sample the posterior using our unnormalized log density function. Note that we'll have to use a \"flattened\" version of our nested weights dictionary to be compatible with `tfp.mcmc`, so we use JAX's tree utilities to flatten and unflatten."
   ]
  },
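  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the flatten and unflatten round trip concrete, here is a small sketch using `jax.tree_util` on a toy nested dictionary. The dictionary `toy_weights` and its shapes are made up for illustration and only imitate the nested structure that `joint_sample(bnn)` returns.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# A toy nested dictionary; the layer names and sizes are illustrative only.\n",
    "toy_weights = {\n",
    "    'layer_1': {'w': jnp.zeros((3, 4)), 'b': jnp.zeros(3)},\n",
    "    'layer_2': {'w': jnp.ones((2, 3)), 'b': jnp.ones(2)},\n",
    "}\n",
    "\n",
    "# Flatten into a list of leaves plus a treedef describing the structure.\n",
    "# Dictionary leaves are ordered by sorted key.\n",
    "flat_state, treedef = jax.tree_util.tree_flatten(toy_weights)\n",
    "print([leaf.shape for leaf in flat_state])\n",
    "\n",
    "# Unflatten restores the original nesting from the leaves.\n",
    "restored = jax.tree_util.tree_unflatten(treedef, flat_state)\n",
    "print(restored['layer_2']['w'].shape)\n",
    "```\n",
    "\n",
    "The flat list is what the MCMC kernel sees as its state, while the treedef lets us rebuild the named dictionary that the log-density function expects."
   ]
  },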
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "nmjmxzGhN855"
   },
   "outputs": [],
   "source": [
    "@jit\n",
    "def run_chain(key, weights):\n",
    "  flat_state, sample_tree = jax.tree_flatten(weights)\n",
    "\n",
    "  def flat_log_prob(*states):\n",
    "    return target_log_prob(jax.tree_unflatten(sample_tree, states))\n",
    "\n",
    "  def trace_fn(_, results):\n",
    "    return results.inner_results.accepted_results.target_log_prob\n",
    "\n",
    "  flat_states, log_probs = tfp.mcmc.sample_chain(\n",
    "    1000,\n",
    "    num_burnin_steps=9000,\n",
    "    kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
    "        tfp.mcmc.HamiltonianMonteCarlo(flat_log_prob, 1e-3, 100),\n",
    "        9000, target_accept_prob=0.7),\n",
    "    trace_fn=trace_fn,\n",
    "    current_state=flat_state,\n",
    "    seed=key)\n",
    "  samples = jax.tree_unflatten(sample_tree, flat_states)\n",
    "  return samples, log_probs\n",
    "posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "WP35ZNG142Kf"
   },
   "outputs": [],
   "source": [
    "plt.plot(log_probs)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xMJzCKL0Vy_x"
   },
   "source": [
    "We can use our samples to take a Bayesian model averaging (BMA) estimate of the training accuracy. To compute it, we can use `intervene` with `bnn` to \"inject\" posterior weights in place of the ones that are sampled from the key. To compute logits for each data point for each posterior sample, we can double `vmap` over `posterior_weights` and `features`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Lh2TM9nPWEdo"
   },
   "outputs": [],
   "source": [
    "output_logits = vmap(lambda weights: vmap(lambda x: intervene(bnn, **weights)(\n",
    "    random.PRNGKey(0), x))(features))(posterior_weights)\n",
    "output_probs = jax.nn.softmax(output_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "HBQu6K36_Nps"
   },
   "outputs": [],
   "source": [
    "print('Average sample accuracy:', (\n",
    "    output_probs.argmax(axis=-1) == labels[None]).mean())\n",
    "print('BMA accuracy:', (\n",
    "    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())"
   ]
  },
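  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two numbers above differ in where the averaging happens. As a toy illustration (the arrays `toy_probs` and `toy_labels` are made up and unrelated to the Iris results), the sketch below shows a case where each posterior sample alone makes mistakes but the averaged prediction does not.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Toy probabilities with shape (num_samples, num_points, num_classes).\n",
    "toy_probs = jnp.array([\n",
    "    [[0.9, 0.1], [0.4, 0.6]],\n",
    "    [[0.6, 0.4], [0.7, 0.3]],\n",
    "    [[0.2, 0.8], [0.6, 0.4]],\n",
    "])\n",
    "toy_labels = jnp.array([0, 0])\n",
    "\n",
    "# Average sample accuracy: predict with each posterior sample, then average.\n",
    "avg_sample_accuracy = (toy_probs.argmax(axis=-1) == toy_labels[None]).mean()\n",
    "\n",
    "# BMA accuracy: average the probabilities over samples, then predict once.\n",
    "bma_accuracy = (toy_probs.mean(axis=0).argmax(axis=-1) == toy_labels).mean()\n",
    "\n",
    "print(avg_sample_accuracy, bma_accuracy)\n",
    "```"
   ]
  },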
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HQU2STV9LlPx"
   },
   "source": [
    "## Conclusion\n",
    "\n",
    "In Oryx, probabilistic programs are just JAX functions that take in (pseudo-)randomness as an input. Because of Oryx's tight integration with JAX's function transformation system, we can write and manipulate probabilistic programs like we're writing JAX code. This results in a simple but flexible system for building complex models and doing inference."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Probabilistic Programming in Oryx",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
